use super::{const_buf, mut_buf, Fd, IoVec, IoVecMut};
use crate::runtime::{FdDel, FdWait};
use crate::{
    event::{POLLET, POLLIN, POLLOUT},
    Error, Result,
};
use core::ops::Deref;
use core::ptr;

/// 使用约束:
/// 1. 同一个Fd在一个工作线程只能有一个AioFd实例, 如果多个实例并发请求异步io事件，只会响应其中一个
///    这要求异步读写都只能在一个异步任务中完成.
/// 2. 如果必须将异步读写分离或者需要多个读写操作，需要利用Fd::clone复制Fd来实现.
/// 3. 或者Fd满足'static生命周期要求，可以利用hash调度策略，在不同工作线程中分别实现读写
#[repr(C)]
pub struct AioFd<'a> {
    fd: &'a Fd,
    index: usize,
}

unsafe impl Send for AioFd<'_> {}

impl Deref for AioFd<'_> {
    type Target = Fd;
    fn deref(&self) -> &Self::Target {
        self.fd
    }
}

impl<'a> AioFd<'a> {
    pub fn new(fd: &'a Fd) -> Self {
        Self { fd, index: 0 }
    }

    /// 同一个Task中，Fd的其他AioFd实例产生的cookie可以加速AioFd::wait_readable/wait_writable/wait操作.
    /// 如果非本Fd对应的cookie，无加速效果，对功能无影响.
    pub fn new_with(fd: &'a Fd, cookie: usize) -> Self {
        Self { fd, index: cookie }
    }

    /// 返回值可用于AioFd::new_with
    /// cookie值可以加速AioFd::wait_readable/wait_writable/wait操作.
    /// 在task中，cookie和Fd是一一对应的, 如果将这里的返回值用于其他Fd创建AioFd，将无加速效果.
    pub fn wait_cookie(&self) -> usize {
        self.index
    }
}

impl AioFd<'_> {
    /// 如果指定多个事件，可能因为部分事件到来而返回
    #[deprecated(note = "should instead use wait_readable or wait_writable")]
    pub async fn wait(&mut self, events: u32) -> Result<()> {
        FdWait::new(self.fd(), &mut self.index, events | POLLET).await
    }

    /// 等待可读事件的到来
    pub async fn wait_readable(&mut self) -> Result<()> {
        // 底层有两种实现，FdWait支持POLLET，FdWaitOnce是否使用POLLET没有影响
        FdWait::new(self.fd(), &mut self.index, POLLIN | POLLET).await
    }

    /// 等待可写事件的到来
    pub async fn wait_writable(&mut self) -> Result<()> {
        // 底层有两种实现，FdWait支持POLLET，FdWaitOnce是否使用POLLET没有影响
        FdWait::new(self.fd(), &mut self.index, POLLOUT | POLLET).await
    }

    /// 向底层取消注册可读可写的io事件. 如果再也不需要，或者准备关闭fd之前，应该调用.
    /// 如果未调用此接口，就关闭了Fd，此fd的值可能被系统重用，后续新建的Fd将可能无法接收到Io事件.
    /// 主要的原因是只能通过异步函数才能取消注册，而Rust析构无法支持异步，必须程序员显示调用.
    pub async fn wait_none(&mut self) {
        FdDel::new(self.fd(), self.index).await
    }

    /// 收到至少一个字节的数据就返回，如果返回Ok(0)说明对端断开连接.
    pub async fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
        self.do_read(buf, false, |fd, buf| unsafe {
            libc::read(fd, mut_buf(buf), buf.len())
        })
        .await
    }

    /// 读取的数据填满buf后才返回，除非对端断链，此时返回Ok(size)将小于buf的长度.
    pub async fn read_all(&mut self, buf: &mut [u8]) -> Result<usize> {
        self.do_read(buf, true, |fd, buf| unsafe {
            libc::read(fd, mut_buf(buf), buf.len())
        })
        .await
    }

    pub(crate) async fn do_read<F>(
        &mut self,
        mut buf: &mut [u8],
        all: bool,
        mut f: F,
    ) -> Result<usize>
    where
        F: FnMut(i32, &mut [u8]) -> isize,
    {
        let mut bytes = 0;
        loop {
            let ret = f(self.fd.fd, buf);
            // 频繁调用场景，但rust无法指定分支预测功能
            #[allow(clippy::comparison_chain)]
            if ret > 0 {
                let n = ret as usize;
                bytes += n;
                if n == buf.len() {
                    return Ok(bytes);
                }
                buf = &mut buf[n..];
            } else if ret == 0 {
                return Ok(bytes);
            } else {
                let e = Error::last();
                match e.errno {
                    libc::EAGAIN if all || bytes == 0 => self.wait_readable().await?,
                    libc::EAGAIN => return Ok(bytes),
                    libc::EINTR => continue,
                    _ if bytes > 0 => return Ok(bytes),
                    _ => return Err(e),
                }
            }
        }
    }

    /// 发送至少一个字节的数据就返回.
    pub async fn write(&mut self, buf: &[u8]) -> Result<usize> {
        self.do_write(buf, false, |fd, buf| unsafe {
            libc::write(fd, const_buf(buf), buf.len())
        })
        .await
    }
    /// 将buf的数据全部发送出去后才返回. 也可能因为连接断开导致发送部分数据.
    pub async fn write_all(&mut self, buf: &[u8]) -> Result<usize> {
        self.do_write(buf, true, |fd, buf| unsafe {
            libc::write(fd, const_buf(buf), buf.len())
        })
        .await
    }

    pub(crate) async fn do_write<F>(&mut self, mut buf: &[u8], all: bool, mut f: F) -> Result<usize>
    where
        F: FnMut(i32, &[u8]) -> isize,
    {
        let mut bytes = 0;
        loop {
            let ret = f(self.fd.fd, buf);
            if ret >= 0 {
                let n = ret as usize;
                bytes += n;
                if n == buf.len() {
                    return Ok(bytes);
                }
                buf = &buf[n..];
            } else {
                let e = Error::last();
                match e.errno {
                    libc::EAGAIN if all || bytes == 0 => self.wait_writable().await?,
                    libc::EAGAIN => return Ok(bytes),
                    libc::EINTR => continue,
                    _ if bytes > 0 => return Ok(bytes),
                    _ => return Err(e),
                }
            }
        }
    }

    /// 读取到至少一个字节就返回, 除非对端断链.
    /// off: 表示buf的字节偏移
    pub async fn readv(&mut self, buf: &[&mut [u8]], off: usize) -> Result<usize> {
        self.do_readv(buf, off, false, |fd, piov, iovcnt| unsafe {
            libc::readv(fd, piov, iovcnt)
        })
        .await
    }

    /// 读取的数据填满buf后才返回，除非对端断链，此时返回Ok(size)将小于buf的长度.
    /// off: 表示buf的字节偏移
    pub async fn readv_all(&mut self, buf: &[&mut [u8]], off: usize) -> Result<usize> {
        self.do_readv(buf, off, true, |fd, piov, iovcnt| unsafe {
            libc::readv(fd, piov, iovcnt)
        })
        .await
    }

    pub(crate) async fn do_readv<F>(
        &mut self,
        buf: &[&mut [u8]],
        off: usize,
        all: bool,
        mut f: F,
    ) -> Result<usize>
    where
        F: FnMut(i32, *const libc::iovec, i32) -> isize,
    {
        let iovec = &mut [libc::iovec {
            iov_base: ptr::null_mut(),
            iov_len: 0,
        }; 64];
        let mut buf = IoVecMut::new(buf, off);
        let mut bytes = 0_usize;
        'top: while let Some((piovec, iovcnt, size)) = buf.to_iovec(iovec) {
            loop {
                let ret = f(self.fd.fd(), piovec, iovcnt);
                if ret > 0 {
                    bytes += ret as usize;
                    if ret as usize == size {
                        buf = buf.next_iovec(iovcnt as usize);
                    } else {
                        buf = buf.next_bytes(ret as usize);
                    }
                    continue 'top;
                } else if ret == 0 {
                    return Ok(bytes);
                } else {
                    let errno = Error::last().errno;
                    match errno {
                        hierr::EAGAIN if all || bytes == 0 => self.wait_readable().await?,
                        hierr::EAGAIN => return Ok(bytes),
                        hierr::EINTR => continue,
                        _ if bytes > 0 => return Ok(bytes),
                        _ => return Err(Error::new(errno)),
                    }
                }
            }
        }
        Ok(bytes)
    }

    /// 至少发送一个字节才返回
    /// off: 表示buf的字节偏移
    pub async fn writev(&mut self, buf: &[&[u8]], off: usize) -> Result<usize> {
        self.do_writev(buf, off, false, |fd, piov, iovcnt| unsafe {
            libc::writev(fd, piov, iovcnt)
        })
        .await
    }

    /// 将buf的数据全部发送出去后才返回. 也可能因为连接断开导致发送部分数据.
    /// off: 表示buf的字节偏移
    pub async fn writev_all(&mut self, buf: &[&[u8]], off: usize) -> Result<usize> {
        self.do_writev(buf, off, true, |fd, piov, iovcnt| unsafe {
            libc::writev(fd, piov, iovcnt)
        })
        .await
    }

    pub(crate) async fn do_writev<F>(
        &mut self,
        buf: &[&[u8]],
        off: usize,
        all: bool,
        mut f: F,
    ) -> Result<usize>
    where
        F: FnMut(i32, *const libc::iovec, i32) -> isize,
    {
        let iovec = &mut [libc::iovec {
            iov_base: ptr::null_mut(),
            iov_len: 0,
        }; 64];
        let mut buf = IoVec::new(buf, off);
        let mut bytes = 0_usize;
        'top: while let Some((piovec, iovcnt, size)) = buf.to_iovec(iovec) {
            loop {
                let ret = f(self.fd.fd(), piovec, iovcnt);
                if ret > 0 {
                    bytes += ret as usize;
                    if ret as usize == size {
                        buf = buf.next_iovec(iovcnt as usize);
                    } else {
                        buf = buf.next_bytes(ret as usize);
                    }
                    continue 'top;
                } else {
                    let errno = Error::last().errno;
                    match errno {
                        hierr::EAGAIN if all || bytes == 0 => self.wait_writable().await?,
                        hierr::EAGAIN => return Ok(bytes),
                        hierr::EINTR => continue,
                        _ if bytes > 0 => return Ok(bytes),
                        _ => return Err(Error::new(errno)),
                    }
                }
            }
        }
        Ok(bytes)
    }

    /// 发送至少一个字节的数据就返回.
    pub async fn sendfile(&mut self, in_fd: i32, off: usize, count: usize) -> Result<usize> {
        self.do_sendfile(in_fd, off, count, false).await
    }

    /// 将数据全部发送出去后才返回. 也可能因为连接断开，只发送了部分数据.
    pub async fn sendfile_all(&mut self, in_fd: i32, off: usize, count: usize) -> Result<usize> {
        self.do_sendfile(in_fd, off, count, true).await
    }

    async fn do_sendfile(
        &mut self,
        in_fd: i32,
        off: usize,
        count: usize,
        all: bool,
    ) -> Result<usize> {
        let mut off = off as i64;
        let mut len = count;
        let mut bytes = 0_usize;
        loop {
            let ret = unsafe { libc::sendfile(self.fd.fd, in_fd, &mut off, len) };
            if ret > 0 {
                bytes += ret as usize;
                if len == ret as usize {
                    return Ok(bytes);
                } else {
                    len -= ret as usize;
                }
            } else {
                let e = Error::last();
                match e.errno {
                    hierr::EAGAIN if all || bytes == 0 => self.wait_writable().await?,
                    hierr::EAGAIN => return Ok(bytes),
                    hierr::EINTR => continue,
                    _ if bytes > 0 => return Ok(bytes),
                    _ => return Err(e),
                }
            }
        }
    }
}
